{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code for the figures in which an image is restored from a fraction of its pixels (fig. 7, bottom, in the paper; fig. 14 in the supplementary material)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "#os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "import numpy as np\n",
    "from models.resnet import ResNet\n",
    "from models.unet import UNet\n",
    "from models.skip import skip\n",
    "from models import get_net\n",
    "import torch\n",
    "import torch.optim\n",
    "# skimage.measure.compare_psnr was removed in newer scikit-image releases;\n",
    "# its replacement lives in skimage.metrics. Keep the old name so the rest of\n",
    "# the notebook is unchanged, and fall back to the old location if needed.\n",
    "try:\n",
    "    from skimage.metrics import peak_signal_noise_ratio as compare_psnr\n",
    "except ImportError:\n",
    "    from skimage.measure import compare_psnr\n",
    "\n",
    "from utils.inpainting_utils import *\n",
    "\n",
    "torch.backends.cudnn.enabled = True\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "# Tensor type every model/tensor below is cast to via .type(dtype).\n",
    "# Same CUDA type as before when a GPU is present; CPU float tensors otherwise.\n",
    "dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
    "\n",
    "PLOT = True      # closure() only plots (and checkpoints) when this is True\n",
    "imsize = -1      # passed straight to get_image()\n",
    "dim_div_by = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Choose figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the image to restore: leave exactly one assignment to `f` active.\n",
    "# fig. 7 (bottom)\n",
    "f = './data/restoration/barbara.png'\n",
    "\n",
    "# fig. 14 of supmat\n",
    "# f = './data/restoration/kate.png'\n",
    "\n",
    "\n",
    "img_pil, img_np = get_image(f, imsize)\n",
    "\n",
    "if 'barbara' in f:\n",
    "    # Reflection-pad the image by one pixel on every side (presumably so its\n",
    "    # size suits the network -- TODO confirm). torch.nn is named explicitly\n",
    "    # rather than relying on `nn` leaking in through the star import.\n",
    "    img_np = torch.nn.ReflectionPad2d(1)(np_to_torch(img_np))[0].numpy()\n",
    "    img_pil = np_to_pil(img_np)\n",
    "\n",
    "    img_mask = get_bernoulli_mask(img_pil, 0.50)\n",
    "    img_mask_np = pil_to_np(img_mask)\n",
    "elif 'kate' in f:\n",
    "    img_mask = get_bernoulli_mask(img_pil, 0.98)\n",
    "\n",
    "    img_mask_np = pil_to_np(img_mask)\n",
    "    # Copy channel 0 of the mask into channels 1 and 2 so all three\n",
    "    # channels share the same mask.\n",
    "    img_mask_np[1] = img_mask_np[0]\n",
    "    img_mask_np[2] = img_mask_np[0]\n",
    "else:\n",
    "    # Explicit error instead of `assert False`: asserts vanish under -O and\n",
    "    # give no hint about what went wrong.\n",
    "    raise ValueError('Unsupported image %r: expected a path containing barbara or kate' % f)\n",
    "\n",
    "\n",
    "# Image with masked-out pixels set to zero; also used for the masked PSNR in closure()\n",
    "img_masked = img_np * img_mask_np\n",
    "\n",
    "mask_var = np_to_torch(img_mask_np).type(dtype)\n",
    "\n",
    "plot_image_grid([img_np, img_mask_np, img_masked], 3,11);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set up everything"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting cadence and figure size used by closure()\n",
    "show_every = 50\n",
    "figsize = 5\n",
    "\n",
    "# Settings shared by both figures\n",
    "pad = 'reflection' # 'zero'\n",
    "INPUT = 'noise'\n",
    "input_depth = 32\n",
    "OPTIMIZER = 'adam'\n",
    "OPT_OVER = 'net'\n",
    "NET_TYPE = 'skip'\n",
    "\n",
    "# Per-figure hyper-parameters and network.\n",
    "# reg_noise_std > 0 makes closure() perturb the network input on every call.\n",
    "if 'barbara' in f:\n",
    "    LR = 0.001\n",
    "    num_iter = 11000\n",
    "    reg_noise_std = 0.03\n",
    "\n",
    "    net = get_net(input_depth, NET_TYPE, pad, n_channels=1,\n",
    "                  skip_n33d=128, \n",
    "                  skip_n33u=128, \n",
    "                  skip_n11=4, \n",
    "                  num_scales=5,\n",
    "                  upsample_mode='bilinear').type(dtype)\n",
    "elif 'kate' in f:\n",
    "    LR = 0.01\n",
    "    num_iter = 1000\n",
    "    reg_noise_std = 0.00\n",
    "\n",
    "    # Output channel count follows the image (first axis of img_np)\n",
    "    net = skip(input_depth, \n",
    "               img_np.shape[0], \n",
    "               num_channels_down = [16, 32, 64, 128, 128],\n",
    "               num_channels_up   = [16, 32, 64, 128, 128],\n",
    "               num_channels_skip =    [0, 0, 0, 0, 0],   \n",
    "               filter_size_down = 3, filter_size_up = 3, filter_skip_size=1,\n",
    "               upsample_mode='bilinear', \n",
    "               downsample_mode='avg',\n",
    "               need_sigmoid=True, need_bias=True, pad=pad).type(dtype)\n",
    "else:\n",
    "    # Without this branch an unsupported `f` only surfaces later as a NameError on `net`\n",
    "    raise ValueError('No configuration for image %r: expected a path containing barbara or kate' % f)\n",
    "\n",
    "# Loss\n",
    "mse = torch.nn.MSELoss().type(dtype)\n",
    "img_var = np_to_torch(img_np).type(dtype)\n",
    "\n",
    "# Network input, sized from the spatial dimensions of the image\n",
    "net_input = get_noise(input_depth, INPUT, img_np.shape[1:]).type(dtype).detach()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def closure():\n",
    "    \"\"\"Run one forward/backward pass on the masked MSE loss and return the loss.\n",
    "\n",
    "    Reads and updates the module-level globals ``i`` (iteration counter),\n",
    "    ``psrn_masked_last`` / ``last_net`` (last accepted checkpoint) and\n",
    "    ``net_input``. Every ``show_every`` iterations (only while ``PLOT`` is\n",
    "    true) the current output is plotted and a checkpoint is taken; if the\n",
    "    masked PSNR has dropped by more than 5 since the last checkpoint, the\n",
    "    saved parameters are restored instead and ``total_loss * 0`` is returned\n",
    "    without advancing ``i``.\n",
    "    \"\"\"\n",
    "    global i, psrn_masked_last, last_net, net_input\n",
    "\n",
    "    if reg_noise_std > 0:\n",
    "        # Resample the noise buffer in place and perturb the pristine input\n",
    "        net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n",
    "\n",
    "    out = net(net_input)\n",
    "\n",
    "    # Loss is computed on the masked output and masked image only\n",
    "    total_loss = mse(out * mask_var, img_var * mask_var)\n",
    "    total_loss.backward()\n",
    "\n",
    "    # Bring the prediction to the host once and reuse it for both PSNR values\n",
    "    out_host = out.detach().cpu().numpy()[0]\n",
    "    psrn_masked = compare_psnr(img_masked, out_host * img_mask_np)\n",
    "    psrn = compare_psnr(img_np, out_host)\n",
    "\n",
    "    print ('Iteration %05d    Loss %f PSNR_masked %f PSNR %f' % (i, total_loss.item(), psrn_masked, psrn),'\\r', end='')\n",
    "\n",
    "    if PLOT and i % show_every == 0:\n",
    "        # Backtracking (only possible once a checkpoint exists)\n",
    "        if last_net is not None and psrn_masked - psrn_masked_last < -5:\n",
    "            print('Falling back to previous checkpoint.')\n",
    "\n",
    "            # copy_ moves the data onto each parameter's own device, so no\n",
    "            # explicit .cuda() is needed here.\n",
    "            for new_param, net_param in zip(last_net, net.parameters()):\n",
    "                net_param.data.copy_(new_param)\n",
    "\n",
    "            # NOTE(review): gradients from the rejected pass are still in .grad;\n",
    "            # whether optimize() applies them is not visible from this cell.\n",
    "            return total_loss*0\n",
    "        else:\n",
    "            # Detached, guaranteed-independent CPU copies: a plain .cpu() is a\n",
    "            # no-op alias when the parameters already live on the CPU.\n",
    "            last_net = [x.detach().to('cpu', copy=True) for x in net.parameters()]\n",
    "            psrn_masked_last = psrn_masked\n",
    "\n",
    "        out_np = torch_to_np(out)\n",
    "        plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)\n",
    "\n",
    "    i += 1\n",
    "\n",
    "    return total_loss\n",
    "\n",
    "# Reset the bookkeeping that closure() keeps in module-level globals\n",
    "i = 0\n",
    "psrn_masked_last = 0\n",
    "last_net = None\n",
    "\n",
    "# closure() rebuilds its input as net_input_saved + noise * reg_noise_std:\n",
    "# `noise` is a same-shaped buffer it resamples in place, and\n",
    "# `net_input_saved` is the untouched copy of the input.\n",
    "noise = net_input.detach().clone()\n",
    "net_input_saved = net_input.detach().clone()\n",
    "\n",
    "# Collect the parameters selected by OPT_OVER and start the optimisation\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, num_iter=num_iter, LR=LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final forward pass is inference only: no_grad() avoids building an\n",
    "# autograd graph for it.\n",
    "with torch.no_grad():\n",
    "    out_np = torch_to_np(net(net_input))\n",
    "\n",
    "# Restored image (clipped to [0, 1]) next to the original\n",
    "q = plot_image_grid([np.clip(out_np, 0, 1), img_np], factor=13);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
